

# Capitalize the first character of a single string, e.g. "age" -> "Age".
# All remaining characters are returned unchanged.
capitalize_firstLetter <- function(char) {
  
  checkmate::assert_character(char, len = 1)
  
  first_char <- substr(char, start = 1, stop = 1)
  rest_chars <- substr(char, start = 2, stop = nchar(char))
  
  paste0(toupper(first_char), rest_chars)
}

# Heatmap of the difference between two APC surfaces
#
# Plots either the observed values of `y_var` in `dat` (if `y_var` is given)
# or the difference `model1 - model2` between the estimated APC surfaces of
# two `gam` models, as a heatmap over two of the dimensions age, period and
# cohort. Model-based differences can additionally be shown next to heatmaps
# of the lower and upper boundaries of their pointwise 95% confidence
# intervals.
#
# Arguments
#   dat: data.frame with (at least) numeric columns `age` and `period`.
#   y_var: Optional name of a column in `dat`. If specified, the observed
#     values are plotted and `model1` / `model2` are not used.
#   model1, model2: Fitted `gam` objects. Both are required if `y_var` is
#     NULL; the plotted effect is the APC surface of `model1` minus the one of
#     `model2`.
#   dimensions: Two of "age", "period", "cohort", mapped to the x and y axis.
#   apc_range: Optional named list restricting the plotted ages, periods
#     and/or cohorts to the given values.
#   bin_heatmap: If TRUE, the surface is averaged within the bins defined by
#     `bin_heatmapGrid_list`.
#   bin_heatmapGrid_list: Optional named list of bin breaks per dimension.
#     Defaults to bins of width 5 covering both plotted dimensions.
#   markLines_list, markLines_displayLabels: Reference lines and the
#     dimensions whose lines get labels, see `gg_addReferenceLines()`.
#   y_var_logScale: If TRUE, the observed `y_var` values are log10-transformed.
#   plot_CI: If TRUE, model-based plots are accompanied by heatmaps of the CI
#     boundaries. Ignored (set to FALSE) when plotting observed data.
#   legend_limits: Optional numeric vector of length 2 with the limits of the
#     color scale.
#
# Returns a ggplot object or, if confidence intervals are plotted, the
# arrangement of the three heatmaps created by `ggpubr::ggarrange()`.
plot_APCheatmap_diff <- function(dat, y_var = NULL, model1 = NULL,
                                 model2 = NULL,
                                 dimensions = c("period","age"), apc_range = NULL,
                                 bin_heatmap = TRUE, bin_heatmapGrid_list = NULL,
                                 markLines_list = NULL,
                                 markLines_displayLabels = c("age","period","cohort"),
                                 y_var_logScale = FALSE, plot_CI = TRUE,
                                 legend_limits = NULL) {
  
  checkmate::assert_data_frame(dat)
  checkmate::assert_names(colnames(dat), must.include = c("age", "period"))
  checkmate::assert_true(!is.null(y_var) || !is.null(model1))
  checkmate::assert_character(y_var, len = 1, null.ok = TRUE)
  checkmate::assert_choice(y_var, choices = colnames(dat), null.ok = TRUE)
  checkmate::assert_class(model1, classes = "gam", null.ok = TRUE)
  # 'model2' is only used for model-based plots, where it is mandatory
  checkmate::assert_class(model2, classes = "gam", null.ok = !is.null(y_var))
  checkmate::assert_character(dimensions, len = 2)
  checkmate::assert_subset(dimensions, choices = c("age","period","cohort"))
  checkmate::assert_list(apc_range, types = "numeric", max.len = 3,
                         null.ok = TRUE, any.missing = FALSE)
  checkmate::assert_subset(names(apc_range), choices = c("age","period","cohort"))
  checkmate::assert_logical(bin_heatmap, len = 1)
  checkmate::assert_list(bin_heatmapGrid_list, min.len = 1, max.len = 3,
                         types = "numeric", null.ok = TRUE)
  checkmate::assert_subset(names(bin_heatmapGrid_list), choices = c("age","period","cohort"))
  checkmate::assert_list(markLines_list, min.len = 1, max.len = 3,
                         types = "numeric", null.ok = TRUE)
  checkmate::assert_subset(names(markLines_list), choices = c("age","period","cohort"))
  checkmate::assert_character(markLines_displayLabels, null.ok = TRUE)
  checkmate::assert_subset(markLines_displayLabels, choices = c("age","period","cohort"))
  checkmate::assert_logical(y_var_logScale, len = 1)
  checkmate::assert_logical(plot_CI, len = 1)
  checkmate::assert_numeric(legend_limits, len = 2, null.ok = TRUE)
  
  
  # some NULL definitions to appease CRAN checks regarding use of dplyr/ggplot2
  period <- age <- effect <- se <- exp_effect <- exp_se <- upper <- lower <-
    exp_upper <- exp_lower <- cohort <- x <- y <- plot_effect <- plot_lower <-
    plot_upper <- NULL
  
  # standard normal quantile for pointwise 95% confidence intervals
  ci_quantile <- qnorm(0.975)
  
  # Labels of all smooth terms of a gam model that involve age or period,
  # i.e. of the terms making up the estimated APC surface.
  get_APCsurfaceTerms <- function(model) {
    terms_model <- vapply(model$smooth, function(s) s$label, character(1))
    terms_model[grepl("age", terms_model) | grepl("period", terms_model)]
  }
  
  # Default bin breaks of width 'width' for a vector of APC values. The last
  # break is extended so that the maximum value also falls into a bin, and
  # there are always at least two breaks (cut() would interpret a single
  # break as the number of intervals).
  get_defaultBinBreaks <- function(values, width = 5) {
    bin_from <- min(values, na.rm = TRUE) - 1
    bin_to   <- max(values, na.rm = TRUE)
    seq(bin_from, bin_from + width * ceiling((bin_to - bin_from) / width),
        by = width)
  }
  
  
  if (!is.null(y_var)) { # plot observed structures
    
    dat <- dat %>% 
      mutate(cohort = period - age) %>% 
      dplyr::rename(effect = dplyr::all_of(y_var)) %>% # rename 'y_var' for easier handling
      filter(!is.na(effect))
    
    plot_dat <- dat
    
    # if 'y_var' is not binned, take the average of observations with the same
    # age and period, to prevent overplotting
    if (!bin_heatmap) {
      plot_dat <- plot_dat %>% 
        group_by(period, age) %>% 
        summarize(effect = mean(effect), .groups = "drop")
    }
    
    # create some variables and objects, to re-use the model-based code
    plot_dat <- plot_dat %>% 
      mutate(cohort = period - age,
             upper  = effect,
             lower  = effect)
    dat_predictionGrid <- plot_dat
    
    if (y_var_logScale) {
      plot_dat <- plot_dat %>% mutate(effect = log10(effect))
    }
    
    plot_CI      <- FALSE
    used_logLink <- FALSE
    legend_title <- if (y_var_logScale) {
      paste0("average log10(", y_var, ")")
    } else {
      paste0("average ", y_var)
    }
    y_trans      <- "identity"
    
    
  } else { # plot smoothed, model-based structures
    
    # create a dataset for predicting the values of the APC surface
    grid_age    <- min(dat$age, na.rm = TRUE):max(dat$age, na.rm = TRUE)
    grid_period <- min(dat$period, na.rm = TRUE):max(dat$period, na.rm = TRUE)
    dat_predictionGrid <- expand.grid(age    = grid_age,
                                      period = grid_period) %>% 
      mutate(cohort = period - age)
    # fill all further covariates of both models with the values of the first
    # observation in 'dat', necessary for calling mgcv::predict.gam
    covars <- unique(c(attr(model1$terms, "term.labels"),
                       attr(model2$terms, "term.labels")))
    covars <- setdiff(covars, c("age", "period", "cohort"))
    if (length(covars) > 0) {
      dat_predictionGrid[, covars] <- dat[1, covars]
    }
    
    # estimated APC surface (with standard errors) of one model on the grid
    predict_APCsurface <- function(model) {
      mgcv::predict.gam(object  = model,
                        newdata = dat_predictionGrid,
                        type    = "terms",
                        terms   = get_APCsurfaceTerms(model),
                        se.fit  = TRUE)
    }
    prediction1 <- predict_APCsurface(model1)
    prediction2 <- predict_APCsurface(model2)
    
    # the standard error of the difference treats both estimates as
    # independent, i.e. any covariance between the two fits is ignored
    plot_dat <- dat_predictionGrid %>%
      mutate(effect = as.vector(prediction1$fit - prediction2$fit),
             se     = as.vector(sqrt((prediction1$se.fit)^2 + (prediction2$se.fit)^2))) %>% 
      mutate(lower  = effect - ci_quantile * se,
             upper  = effect + ci_quantile * se)
    
    used_logLink <- (model1$family[[2]] %in% c("log","logit")) ||
      grepl("Ordered Categorical", model1$family[[1]])
    legend_title <- if (used_logLink) "Odds ratio" else "Mean effect"
    y_trans      <- if (used_logLink) "log" else "identity"
    
    if (used_logLink) {
      # delta method: se(exp(effect)) = exp(effect) * se(effect)
      plot_dat <- plot_dat %>% 
        mutate(exp_effect = exp(effect),
               exp_se     = sqrt((se^2) * (exp_effect^2))) %>% 
        mutate(exp_lower  = exp_effect - ci_quantile * exp_se,
               exp_upper  = exp_effect + ci_quantile * exp_se) %>% 
        select(-effect, -se, -upper, -lower) %>% 
        dplyr::rename(effect = exp_effect, se = exp_se,
                      upper  = exp_upper, lower = exp_lower)
    }
    
  }
  
  # filter the data
  if (!is.null(apc_range)) {
    if (!is.null(apc_range$age)) {
      plot_dat <- plot_dat %>% filter(age %in% apc_range$age)
    }
    if (!is.null(apc_range$period)) {
      plot_dat <- plot_dat %>% filter(period %in% apc_range$period)
    }
    if (!is.null(apc_range$cohort)) {
      plot_dat <- plot_dat %>% filter(cohort %in% apc_range$cohort)
    }
  }
  
  # bin the heatmap surface, if necessary
  if (!bin_heatmap) { # no binning
    
    plot_dat <- plot_dat %>% 
      dplyr::rename(plot_effect = effect, plot_upper = upper, plot_lower = lower)
    
  } else { # bin the heatmap
    
    # define the binning grid, if still necessary
    if (is.null(bin_heatmapGrid_list)) {
      bin_heatmapGrid_list <- lapply(dimensions, function(dim_name) {
        get_defaultBinBreaks(dat_predictionGrid[[dim_name]])
      })
      names(bin_heatmapGrid_list) <- dimensions
    }
    
    dims_toBin       <- names(bin_heatmapGrid_list)
    dims_catVarNames <- paste0(dims_toBin, "_cat")
    
    for (i in seq_along(dims_toBin)) {
      plot_dat[[dims_catVarNames[i]]] <- cut(plot_dat[[dims_toBin[i]]],
                                             breaks = bin_heatmapGrid_list[[dims_toBin[i]]])
    }
    
    # every tile shows the average over all grid points in its bin
    plot_dat <- plot_dat %>% 
      group_by(dplyr::across(dplyr::all_of(dims_catVarNames))) %>% 
      mutate(plot_effect = mean(effect),
             plot_lower  = mean(lower),
             plot_upper  = mean(upper)) %>% 
      ungroup()
  }
  
  
  # create variables x, y, z additional to the APC variables, for easier handling
  plot_dat$x <- plot_dat[[dimensions[1]]]
  plot_dat$y <- plot_dat[[dimensions[2]]]
  dim_3      <- setdiff(c("age", "period", "cohort"), dimensions)[1]
  plot_dat$z <- plot_dat[[dim_3]]
  
  x_lab <- capitalize_firstLetter(dimensions[1])
  y_lab <- capitalize_firstLetter(dimensions[2])
  
  
  # overall theme, applied to every heatmap
  gg_theme <- theme(plot.title       = element_text(hjust = 0.5),
                    legend.position  = "bottom",
                    legend.key.width = unit(1.2, "cm"))
  
  # create the base heatmap plot
  gg_effect <- ggplot() +
    geom_tile(data = plot_dat, aes(x = x, y = y, fill = plot_effect)) +
    xlab(x_lab) + ylab(y_lab) + theme_minimal() + gg_theme
  
  if (!plot_CI) { # no confidence intervals to be plotted
    
    gg_list <- list(gg_effect)
    
  } else { # add heatmaps for the confidence interval borders to the plot
    
    gg_lower <- ggplot() +
      geom_tile(data = plot_dat, aes(x = x, y = y, fill = plot_lower)) +
      ggtitle("Lower 95% CI boundary") + xlab(x_lab) + theme_minimal() +
      gg_theme + theme(axis.title.y = element_blank())
    gg_upper <- ggplot() +
      geom_tile(data = plot_dat, aes(x = x, y = y, fill = plot_upper)) +
      ggtitle("Upper 95% CI boundary") + xlab(x_lab) + theme_minimal() +
      gg_theme + theme(axis.title.y = element_blank())
    
    gg_list <- list(gg_effect, gg_lower, gg_upper)
  }
  
  round_3 <- function(x) {
    round(x, digits = 3)
  }
  
  # color scale: diverging around zero difference for model-based plots and
  # around the average value for observed data
  scale_midpoint <- if (is.null(y_var)) 0 else mean(plot_dat$plot_effect)
  # the fixed breaks are tailored to odds ratios on the log scale; for all
  # other scales, ggplot2 chooses the breaks guided by 'n.breaks'
  scale_breaks <- if (used_logLink) c(0.6, 1, 1.5, 2.5, 4.5) else ggplot2::waiver()
  gg_list <- lapply(gg_list, function(gg) {
    gg + scale_fill_gradient2(legend_title, limits = legend_limits,
                              breaks = scale_breaks,
                              labels = round_3,
                              n.breaks = 8,
                              trans = y_trans, low = "dodgerblue4",
                              mid = "white", high = "firebrick3",
                              midpoint = scale_midpoint)
  })
  
  
  # mark some age groups / periods / cohorts in each plot
  if (!is.null(markLines_list)) {
    gg_list <- gg_addReferenceLines(gg_list                 = gg_list,
                                    dimensions              = dimensions,
                                    plot_dat                = plot_dat,
                                    markLines_list          = markLines_list,
                                    markLines_displayLabels = markLines_displayLabels)
  }
  
  
  # create final plot output
  if (!plot_CI) {
    plot_final <- gg_list[[1]]
    
  } else {
    plot_final <- ggpubr::ggarrange(plotlist      = gg_list,
                                    legend        = "bottom",
                                    common.legend = FALSE,
                                    ncol          = 3,
                                    widths        = c(.34,.32,.32))
  }
  
  return(plot_final)
}



#' Internal helper to add reference lines in an APC heatmap
#' 
#' Adds dashed reference lines to every heatmap in a list of ggplot objects.
#' Lines for the dimension on the x axis are drawn vertically, lines for the
#' dimension on the y axis horizontally, and lines for the remaining (third)
#' APC dimension diagonally. Lines of the dimensions listed in
#' \code{markLines_displayLabels} are additionally labeled. To be called from
#' within \code{\link{plot_APCheatmap}}.
#' 
#' @inheritParams plot_APCheatmap
#' @param gg_list Existing list of ggplot objects where the reference lines
#' should be marked in each individual ggplot.
#' @param plot_dat Dataset used for creating the heatmap, containing the
#' columns \code{x} and \code{y} of the plotted axes.
#' @return The list \code{gg_list}, where each ggplot object is extended by
#' the reference line (and label) layers.
#' 
#' @import dplyr ggplot2
#' 
gg_addReferenceLines <- function(gg_list, dimensions, plot_dat, markLines_list,
                                 markLines_displayLabels) {
  
  # some NULL definitions to appease CRAN checks regarding use of dplyr/ggplot2
  x <- y <- x_start <- x_end <- y_start <- y_end <- group <- NULL
  
  dim_x <- dimensions[1]
  dim_y <- dimensions[2]
  # the APC dimension without an axis of its own is marked diagonally
  dim_diag <- if (!("age" %in% dimensions)) {
    "age"
  } else if (!("period" %in% dimensions)) {
    "period"
  } else {
    "cohort"
  }
  line_color <- gray(0.3)
  
  # append a freshly created layer to every ggplot object in 'plots'
  add_toEachPlot <- function(plots, create_layer) {
    lapply(plots, function(gg) gg + create_layer())
  }
  
  
  # vertical lines (x axis dimension)
  if (dim_x %in% names(markLines_list)) {
    x_values <- markLines_list[[dim_x]]
    gg_list  <- add_toEachPlot(gg_list, function() {
      geom_vline(xintercept = x_values, col = line_color, lty = 2)
    })
    
    if (dim_x %in% markLines_displayLabels) {
      labels_vertical <- data.frame(x = x_values, y = max(plot_dat$y))
      gg_list <- add_toEachPlot(gg_list, function() {
        geom_label(data = labels_vertical, aes(x = x, y = y, label = x),
                   hjust = 0, nudge_x = 1, nudge_y = 1)
      })
    }
  }
  
  # horizontal lines (y axis dimension)
  if (dim_y %in% names(markLines_list)) {
    y_values <- markLines_list[[dim_y]]
    gg_list  <- add_toEachPlot(gg_list, function() {
      geom_hline(yintercept = y_values, col = line_color, lty = 2)
    })
    
    if (dim_y %in% markLines_displayLabels) {
      labels_horizontal <- data.frame(x = max(plot_dat$x), y = y_values)
      gg_list <- add_toEachPlot(gg_list, function() {
        geom_label(data = labels_horizontal, aes(x = x, y = y, label = y),
                   vjust = 0, nudge_x = 1, nudge_y = 1)
      })
    }
  }
  
  # diagonal lines (third dimension)
  if (dim_diag %in% names(markLines_list)) {
    diag_values <- markLines_list[[dim_diag]]
    y_min       <- min(plot_dat$y)
    y_max       <- max(plot_dat$y)
    
    # x coordinate of the diagonal 'dim_diag == z' at height y, following
    # from the identity cohort = period - age
    x_onDiagonal <- function(y_value, z) {
      if (dim_diag == "period") {
        z - y_value
      } else if (dim_diag == "cohort") {
        if (dim_x == "period") y_value + z else y_value - z
      } else { # dim_diag == "age"
        if (dim_x == "cohort") y_value - z else y_value + z
      }
    }
    
    # one row per diagonal; only the first label names the dimension
    segment_rows <- lapply(diag_values, function(z) {
      segment_label <- if (match(z, diag_values) == 1) {
        paste(capitalize_firstLetter(dim_diag), z)
      } else {
        as.character(z)
      }
      data.frame(x_start = x_onDiagonal(y_min, z),
                 x_end   = x_onDiagonal(y_max, z),
                 y_start = y_min,
                 y_end   = y_max,
                 group   = segment_label)
    })
    dat_segments <- dplyr::bind_rows(segment_rows)
    
    # cut segments that would exceed the horizontal plot range
    dat_segments <- ensure_segmentsInPlotRange(dat_segments, plot_dat)
    
    gg_list <- add_toEachPlot(gg_list, function() {
      geom_segment(data = dat_segments,
                   aes(x = x_start, xend = x_end, y = y_start, yend = y_end,
                       group = group),
                   col = line_color, lty = 2)
    })
    
    if (dim_diag %in% markLines_displayLabels) {
      gg_list <- add_toEachPlot(gg_list, function() {
        geom_label(data = dat_segments,
                   aes(x = x_end, y = y_end, label = group))
      })
    }
  }
  
  return(gg_list)
}



#' Internal helper for gg_addReferenceLines to keep diagonal lines in the plot range
#' 
#' Internal helper function to be called from within
#' \code{\link{gg_addReferenceLines}}. This function takes the dataset prepared
#' for adding diagonal reference lines in the plot, checks if some diagonals
#' exceed the horizontal plot limits and cuts them accordingly, if necessary.
#' Diagonals that do not intersect the horizontal plot range at all are
#' removed. The corrected dataset is returned.
#' 
#' @inheritParams gg_addReferenceLines
#' @param dat_segments Dataset containing information on the diagonal reference
#' lines, with columns \code{x_start}, \code{x_end}, \code{y_start} and
#' \code{y_end}.
#' @return The corrected version of \code{dat_segments}.
#' 
ensure_segmentsInPlotRange <- function(dat_segments, plot_dat) {
  
  # nothing to cut, e.g. if an empty vector of reference values was passed
  if (nrow(dat_segments) == 0) {
    return(dat_segments)
  }
  
  x_range <- range(plot_dat$x)
  
  # drop segments lying completely left or right of the plot range. Clipping
  # them would still leave a (reversed) line outside the plot area.
  x_left_segment  <- pmin(dat_segments$x_start, dat_segments$x_end)
  x_right_segment <- pmax(dat_segments$x_start, dat_segments$x_end)
  in_range <- which(x_right_segment >= x_range[1] & x_left_segment <= x_range[2])
  dat_segments <- dat_segments[in_range, , drop = FALSE]
  
  slopes <- (dat_segments$y_end - dat_segments$y_start) /
    (dat_segments$x_end - dat_segments$x_start)
  
  for (i in seq_len(nrow(dat_segments))) {
    
    dat_i <- dat_segments[i, ]
    
    # since the lines in dat_segments sometimes have negative slope, flexibly
    # retrieve the columns of the left-most and right-most point per line
    start_isLeft <- dat_i$x_start < dat_i$x_end
    x_start_var  <- if (start_isLeft) "x_start" else "x_end"
    y_start_var  <- if (start_isLeft) "y_start" else "y_end"
    x_end_var    <- if (start_isLeft) "x_end"   else "x_start"
    y_end_var    <- if (start_isLeft) "y_end"   else "y_start"
    
    # check the start of the line; move it along the line to the left border
    if (dat_i[[x_start_var]] < x_range[1]) {
      dat_segments[i, x_start_var] <- x_range[1]
      dat_segments[i, y_start_var] <- dat_i[[y_end_var]] - slopes[i] * (dat_i[[x_end_var]] - x_range[1])
    }
    
    # check the end of the line; move it along the line to the right border
    if (dat_i[[x_end_var]] > x_range[2]) {
      dat_segments[i, x_end_var] <- x_range[2]
      dat_segments[i, y_end_var] <- dat_i[[y_start_var]] + slopes[i] * (x_range[2] - dat_i[[x_start_var]])
    }
  }
  
  return(dat_segments)
}
